#include "opencl_source_map.hpp" 
namespace MNN { 
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* attention_buf = 
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"#define GLOBAL_SIZE_2_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"#define DEAL_OUTER_SEQLEN_NOT_ALIGN(length) "" if(4 * sl + 3 >= length) {"" temp_3 = (FLOAT4)0;"" }"" if(4 * sl + 2 >= length) {"" temp_2 = (FLOAT4)0;"" }"" if(4 * sl + 1 >= length) {"" temp_1 = (FLOAT4)0;"" }\n"
"#define DEAL_INNER_HEADDIM_NOT_ALIGN(length) "" if(hd * 4 + 3 >= length) {"" temp_0.w = (FLOAT)0;"" temp_1.w = (FLOAT)0;"" temp_2.w = (FLOAT)0;"" temp_3.w = (FLOAT)0;"" }"" if(hd * 4 + 2 >= length) {"" temp_0.z = (FLOAT)0;"" temp_1.z = (FLOAT)0;"" temp_2.z = (FLOAT)0;"" temp_3.z = (FLOAT)0;"" }"" if(hd * 4 + 1 >= length) {"" temp_0.y = (FLOAT)0;"" temp_1.y = (FLOAT)0;"" temp_2.y = (FLOAT)0;"" temp_3.y = (FLOAT)0;"" }\n"
"__kernel void rearrange_qkv(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input_q,//[batch,seqLenQ/4,headNum,headDim,seqLenQ_4]\n"
" __global const FLOAT *input_k,// [batch,seqLenKV/4,headNum/group,headDim,seqLenKV_4]\n"
" __global const FLOAT *input_v,// [batch,seqLenKV/4,headNum/group,headDim,seqLenKV_4]\n"
" __global FLOAT *output_q,// [batch*headNum,ROUND_UP(headDim,mTileHDK),ROUND_UP(seqLenQ,mTileQ)]\n"
" __global FLOAT *output_k,// [batch*headNum/group,ROUND_UP(headDim,mTileHDK),ROUND_UP(seqLenKV,mTileKV)]\n"
" __global FLOAT *output_v,// [batch*headNum/group,ROUND_UP(seqLenKV,mTileKV),ROUND_UP(headDim,mTileHDN)]\n"
" #ifdef SAVE_KV\n"
" __global FLOAT *past_k,// [batch,headNum/group,headDim,seqLenKV_4]\n"
" __global FLOAT *past_v,// [batch,headNum/group,seqLenKV_4,headDim]\n"
" #endif\n"
" __private const int4 tile,// [mTileQ,mTileKV,mTileHDK,mTileHDN]\n"
" __private const int4 shape,// [seqLenQ,seqLenKV,headNum,headDim]\n"
" __private const int4 param,// [group,batch,max_len,past_len]\n"
" __private const int maxLenKV\n"
") {\n"
" const int sl=get_global_id(0); // seqLen/4 : max(seqLenPackQ/4,seqLenPackKV/4)\n"
" const int hd=get_global_id(1); // headDim/4 : max(headDimPackQK/4,headDimPackV/4)\n"
" const int z=get_global_id(2); // batch*headNum\n"
" DEAL_NON_UNIFORM_DIM3(sl,hd,z);\n"
" \n"
" const int seqLenQ=shape.x;\n"
" const int seqLenKV=shape.y;\n"
" const int headNum=shape.z;\n"
" const int headDim=shape.w;\n"
" const int group=param.x;\n"
" const int batch=param.y;\n"
" const int b=z % batch;\n"
" const int hn=z/batch;\n"
" \n"
" const int seqLenQ_4=(seqLenQ+3)/4;\n"
" //const int in_offset_q=(((b*seqLenQ_4+sl)*headNum+hn)*headDim+4*hd)*4;\n"
" const int in_offset_q=(((b*seqLenQ+sl*4)*headNum+hn)*headDim+4*hd);\n"
" const int seqLenPackQ=((seqLenQ+tile.x-1)/tile.x)*tile.x;\n"
" const int headDimPackQK=((headDim+tile.z-1)/tile.z)*tile.z;\n"
" const int out_offset_q=(((b*headNum+hn)*headDimPackQK+hd*4)*seqLenPackQ+sl*4);\n"
" \n"
" if(sl*4<seqLenPackQ && hd*4<headDimPackQK) {\n"
" if(sl*4 >= seqLenQ || hd*4 >= headDim) {\n"
" vstore4((FLOAT4)0,0,output_q+out_offset_q);\n"
" vstore4((FLOAT4)0,0,output_q+out_offset_q+seqLenPackQ);\n"
" vstore4((FLOAT4)0,0,output_q+out_offset_q+2*seqLenPackQ);\n"
" vstore4((FLOAT4)0,0,output_q+out_offset_q+3*seqLenPackQ);\n"
" } else {\n"
" FLOAT4 temp_0=vload4(0,input_q+in_offset_q);\n"
" FLOAT4 temp_1=(sl*4+1 >= seqLenQ) ? (FLOAT4)0 : vload4(0,input_q+in_offset_q+headNum*headDim);\n"
" FLOAT4 temp_2=(sl*4+2 >= seqLenQ) ? (FLOAT4)0 : vload4(0,input_q+in_offset_q+2*headNum*headDim);\n"
" FLOAT4 temp_3=(sl*4+3 >= seqLenQ) ? (FLOAT4)0 : vload4(0,input_q+in_offset_q+3*headNum*headDim);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_INNER_HEADDIM_NOT_ALIGN(headDim)\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenQ)\n"
" #endif\n"
" vstore4((FLOAT4)(temp_0.s0,temp_1.s0,temp_2.s0,temp_3.s0),0,output_q+out_offset_q);\n"
" vstore4((FLOAT4)(temp_0.s1,temp_1.s1,temp_2.s1,temp_3.s1),0,output_q+out_offset_q+seqLenPackQ);\n"
" vstore4((FLOAT4)(temp_0.s2,temp_1.s2,temp_2.s2,temp_3.s2),0,output_q+out_offset_q+2*seqLenPackQ);\n"
" vstore4((FLOAT4)(temp_0.s3,temp_1.s3,temp_2.s3,temp_3.s3),0,output_q+out_offset_q+3*seqLenPackQ);\n"
" }\n"
" }\n"
" \n"
" if(hn >= headNum/group) {\n"
" return;\n"
" }\n"
" \n"
" const int seqLenPackKV=((seqLenKV+tile.y-1)/tile.y)*tile.y;\n"
" const int headDimPackV=((headDim+tile.w-1)/tile.w)*tile.w;\n"
" const int seqLenKV_4=(seqLenKV+3)/4;\n"
" const int in_offset_kv=(((b*seqLenKV+sl*4)*headNum/group+hn)*headDim+4*hd);\n"
" const int past_offset_k=(((b*headNum/group+hn)*headDim+hd*4)*maxLenKV+sl*4);\n"
" const int past_offset_v=(((b*headNum/group+hn)*maxLenKV+sl*4)*headDim+4*hd);\n"
" if(sl*4<seqLenPackKV && hd*4<headDimPackQK) {\n"
" const int out_offset_k=(((b*headNum/group+hn)*headDimPackQK+hd*4)*seqLenPackKV+sl*4);\n"
" if(sl*4 >= seqLenKV || hd*4 >= headDim) {\n"
" vstore4((FLOAT4)0,0,output_k+out_offset_k);\n"
" vstore4((FLOAT4)0,0,output_k+out_offset_k+seqLenPackKV);\n"
" vstore4((FLOAT4)0,0,output_k+out_offset_k+2*seqLenPackKV);\n"
" vstore4((FLOAT4)0,0,output_k+out_offset_k+3*seqLenPackKV);\n"
" } else {\n"
" FLOAT4 temp_0=vload4(0,input_k+in_offset_kv);\n"
" FLOAT4 temp_1=(sl*4+1 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_k+in_offset_kv+headNum*headDim/group);\n"
" FLOAT4 temp_2=(sl*4+2 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_k+in_offset_kv+2*headNum*headDim/group);\n"
" FLOAT4 temp_3=(sl*4+3 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_k+in_offset_kv+3*headNum*headDim/group);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_INNER_HEADDIM_NOT_ALIGN(headDim)\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenKV)\n"
" #endif\n"
" FLOAT4 key0=(FLOAT4)(temp_0.s0,temp_1.s0,temp_2.s0,temp_3.s0);\n"
" FLOAT4 key1=(FLOAT4)(temp_0.s1,temp_1.s1,temp_2.s1,temp_3.s1);\n"
" FLOAT4 key2=(FLOAT4)(temp_0.s2,temp_1.s2,temp_2.s2,temp_3.s2);\n"
" FLOAT4 key3=(FLOAT4)(temp_0.s3,temp_1.s3,temp_2.s3,temp_3.s3);\n"
" vstore4(key0,0,output_k+out_offset_k);\n"
" vstore4(key1,0,output_k+out_offset_k+seqLenPackKV);\n"
" vstore4(key2,0,output_k+out_offset_k+2*seqLenPackKV);\n"
" vstore4(key3,0,output_k+out_offset_k+3*seqLenPackKV);\n"
" \n"
" // pastK\n"
" #ifdef SAVE_KV\n"
" vstore4(key0,0,past_k+past_offset_k);\n"
" vstore4(key1,0,past_k+past_offset_k+maxLenKV);\n"
" vstore4(key2,0,past_k+past_offset_k+2*maxLenKV);\n"
" vstore4(key3,0,past_k+past_offset_k+3*maxLenKV);\n"
" #endif\n"
" }\n"
" \n"
" }\n"
" \n"
" if(sl*4<seqLenPackKV && hd*4<headDimPackV) {\n"
" const int out_offset_v=(((b*headNum/group+hn)*seqLenPackKV+sl*4)*headDimPackV+hd*4);\n"
" if(sl*4 >= seqLenKV || hd*4 >= headDim) {\n"
" vstore4((FLOAT4)0,0,output_v+out_offset_v);\n"
" vstore4((FLOAT4)0,0,output_v+out_offset_v+headDimPackV);\n"
" vstore4((FLOAT4)0,0,output_v+out_offset_v+2*headDimPackV);\n"
" vstore4((FLOAT4)0,0,output_v+out_offset_v+3*headDimPackV);\n"
" } else {\n"
" FLOAT4 temp_0=vload4(0,input_v+in_offset_kv);\n"
" FLOAT4 temp_1=(sl*4+1 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_v+in_offset_kv+headNum*headDim/group);\n"
" FLOAT4 temp_2=(sl*4+2 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_v+in_offset_kv+2*headNum*headDim/group);\n"
" FLOAT4 temp_3=(sl*4+3 >= seqLenKV) ? (FLOAT4)0 : vload4(0,input_v+in_offset_kv+3*headNum*headDim/group);\n"
" #ifdef HEADDIM_LEAVE\n"
" DEAL_INNER_HEADDIM_NOT_ALIGN(headDim)\n"
" #endif\n"
" #ifdef SEQLEN_LEAVE\n"
" DEAL_OUTER_SEQLEN_NOT_ALIGN(seqLenKV)\n"
" #endif\n"
" vstore4(temp_0,0,output_v+out_offset_v);\n"
" vstore4(temp_1,0,output_v+out_offset_v+headDimPackV);\n"
" vstore4(temp_2,0,output_v+out_offset_v+2*headDimPackV);\n"
" vstore4(temp_3,0,output_v+out_offset_v+3*headDimPackV);\n"
" \n"
" // pastV\n"
" #ifdef SAVE_KV\n"
" vstore4(temp_0,0,past_v+past_offset_v);\n"
" vstore4(temp_1,0,past_v+past_offset_v+headDim);\n"
" vstore4(temp_2,0,past_v+past_offset_v+2*headDim);\n"
" vstore4(temp_3,0,past_v+past_offset_v+3*headDim);\n"
" #endif\n"
" }\n"
" \n"
" }\n"
"}\n"
"#ifndef MASK_DTYPE\n"
"#define MASK_DTYPE FLOAT\n"
"#define MASK_DTYPE4 FLOAT4\n"
"#endif\n"
"__kernel void rearrange_mask(GLOBAL_SIZE_3_DIMS\n"
" __global const MASK_DTYPE *input_mask,// [batch,1,seqLenQ,seqLenKV,4]\n"
" __global MASK_DTYPE *output_mask,// [batch,ROUND_UP(seqLenQ,mTileQ),ROUND_UP(seqLenKV,mTileKV)]\n"
" const int4 shape // [seqLenQ,seqLenKV,mTileQ,mTileKV]\n"
") {\n"
" const int sl=get_global_id(0); // seqLen_4\n"
" const int sl_kv=get_global_id(1); // seqLenKV_4\n"
" const int b=get_global_id(2); // Batch\n"
" DEAL_NON_UNIFORM_DIM3(sl,sl_kv,b);\n"
" \n"
" const int seq_len_pack=((shape.x+shape.z-1)/shape.z)*shape.z;\n"
" const int seq_len_kv_pack=((shape.y+shape.w-1)/shape.w)*shape.w;\n"
" int in_offset=((b*shape.x+sl*4)*shape.y+sl_kv*4);\n"
" int out_offset=(b*seq_len_pack+sl*4)*seq_len_kv_pack+sl_kv*4;\n"
" if(sl*4 >= shape.x || sl_kv*4 >= shape.y) {\n"
" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset);\n"
" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset+seq_len_kv_pack);\n"
" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset+seq_len_kv_pack*2);\n"
" vstore4((MASK_DTYPE4)0,0,output_mask+out_offset+seq_len_kv_pack*3);\n"
" } else {\n"
" int y_down_align4=(shape.y/4*4);\n"
" MASK_DTYPE4 temp_0,temp_1,temp_2,temp_3;\n"
" \n"
" if(sl_kv*4<y_down_align4) {\n"
" temp_0=vload4(0,input_mask+in_offset);\n"
" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0,input_mask+in_offset+shape.y);\n"
" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0,input_mask+in_offset+shape.y*2);\n"
" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : vload4(0,input_mask+in_offset+shape.y*3);\n"
" } else if(sl_kv*4+1 == shape.y){\n"
" temp_0=(MASK_DTYPE4)(input_mask[in_offset],0,0,0);\n"
" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y],0,0,0);//vload4(0,input_mask+in_offset+shape.y);\n"
" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*2],0,0,0);//vload4(0,input_mask+in_offset+shape.y*2);\n"
" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*3],0,0,0);//vload4(0,input_mask+in_offset+shape.y*3);\n"
" } else if(sl_kv*4+2 == shape.y){\n"
" temp_0=(MASK_DTYPE4)(input_mask[in_offset],input_mask[in_offset+1],0,0);\n"
// FIX: temp_1 used a (FLOAT4) vector literal while every sibling line uses (MASK_DTYPE4).
// When MASK_DTYPE is int (SET_MASK path) or FLOAT is half (fp16), assigning a float4 to a
// MASK_DTYPE4 variable is an illegal implicit vector conversion in OpenCL C.
" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y],input_mask[in_offset+shape.y+1],0,0);//vload4(0,input_mask+in_offset+shape.y);\n"
" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*2],input_mask[in_offset+shape.y*2+1],0,0);//vload4(0,input_mask+in_offset+shape.y*2);\n"
" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*3],input_mask[in_offset+shape.y*3+1],0,0);//vload4(0,input_mask+in_offset+shape.y*3);\n"
" } else if(sl_kv*4+3 == shape.y){\n"
" temp_0=(MASK_DTYPE4)(input_mask[in_offset],input_mask[in_offset+1],input_mask[in_offset+2],0);\n"
" temp_1=(sl*4+1 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y],input_mask[in_offset+shape.y+1],input_mask[in_offset+shape.y+2],0);//vload4(0,input_mask+in_offset+shape.y);\n"
" temp_2=(sl*4+2 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*2],input_mask[in_offset+shape.y*2+1],input_mask[in_offset+shape.y*2+2],0);//vload4(0,input_mask+in_offset+shape.y*2);\n"
" temp_3=(sl*4+3 >= shape.x) ? (MASK_DTYPE4)0 : (MASK_DTYPE4)(input_mask[in_offset+shape.y*3],input_mask[in_offset+shape.y*3+1],input_mask[in_offset+shape.y*3+2],0);//vload4(0,input_mask+in_offset+shape.y*3);\n"
" }\n"
" vstore4(temp_0,0,output_mask+out_offset);\n"
" vstore4(temp_1,0,output_mask+out_offset+seq_len_kv_pack);\n"
" vstore4(temp_2,0,output_mask+out_offset+2*seq_len_kv_pack);\n"
" vstore4(temp_3,0,output_mask+out_offset+3*seq_len_kv_pack);\n"
" }\n"
"}\n"
"__kernel void qkv_transpose_output(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *input,// [Batch*mNumHead,ROUND_UP(mHeadDim,mTileHDN),ROUND_UP(seqLen,mTileQ)]\n"
" __global FLOAT *output,// [Batch,seqLen/4,mNumHead， mHeadDim,4]\n"
" __private const int tile_q,\n"
" __private const int tile_hdn,\n"
" __private const int seq_len,\n"
" __private const int head_num,\n"
" __private const int head_dim\n"
") {\n"
" \n"
" const int sl=get_global_id(0); // seqLen_4\n"
" const int hd=get_global_id(1); // mHeadDim_4\n"
" const int z=get_global_id(2); // Batch*mNumHead\n"
" DEAL_NON_UNIFORM_DIM3(sl,hd,z);\n"
" \n"
" const int b=z/head_num;\n"
" const int hn=z % head_num;\n"
" \n"
" const int seq_len_pack=((seq_len+tile_q-1)/tile_q)*tile_q;\n"
" const int head_dim_pack=((head_dim+tile_hdn-1)/tile_hdn)*tile_hdn;\n"
" \n"
" const int offset_inp=((b*head_num+hn)*head_dim_pack+4*hd)*seq_len_pack+4*sl;\n"
" \n"
" const int offset_out=(((b*seq_len+sl*4)*head_num+hn)*head_dim+4*hd);\n"
" \n"
" // Q\n"
" FLOAT4 temp_0=vload4(0,input+offset_inp);\n"
" FLOAT4 temp_1=vload4(0,input+offset_inp+seq_len_pack);\n"
" FLOAT4 temp_2=vload4(0,input+offset_inp+2*seq_len_pack);\n"
" FLOAT4 temp_3=vload4(0,input+offset_inp+3*seq_len_pack);\n"
" \n"
" vstore4((FLOAT4)(temp_0.s0,temp_1.s0,temp_2.s0,temp_3.s0),0,output+offset_out);\n"
" if(4*sl+1 >= seq_len) return;\n"
" vstore4((FLOAT4)(temp_0.s1,temp_1.s1,temp_2.s1,temp_3.s1),0,output+offset_out+head_num*head_dim);\n"
" if(4*sl+2 >= seq_len) return;\n"
" vstore4((FLOAT4)(temp_0.s2,temp_1.s2,temp_2.s2,temp_3.s2),0,output+offset_out+2*head_num*head_dim);\n"
" if(4*sl+3 >= seq_len) return;\n"
" vstore4((FLOAT4)(temp_0.s3,temp_1.s3,temp_2.s3,temp_3.s3),0,output+offset_out+3*head_num*head_dim);\n"
"}\n"
"#ifndef NUMHEAD_GROUP_SIZE\n"
"#define NUMHEAD_GROUP_SIZE 1\n"
"#endif\n"
"__kernel void rearrange_q(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *query,// [batch query_seq_len head_num head_dim]\n"
" __global FLOAT *query_tmp,// [batch head_num head_dim_4 query_seq_len_4]\n"
" __private const int seq_len,\n"
" __private const int head_dim,\n"
" __private const int head_num) {\n"
" /*\n"
" the kernel assume head_dim is multiple of 4.\n"
" */\n"
" const int x=get_global_id(0); // query_seq_len/4\n"
" const int y=get_global_id(1); // head_dim/4\n"
" int z=get_global_id(2);\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" const int b=z/head_num;// batch\n"
" z=z % head_num;// head_num\n"
" \n"
" const int x4=x << 2;\n"
" const int y4=y << 2;\n"
" const int seq_len4=(seq_len+3)/4*4;;\n"
" const int stride=head_num*head_dim;\n"
" int query_offset=((b*seq_len+x4)*head_num+z)*head_dim+y4;\n"
" FLOAT4 query_vec0=vload4(0,query+query_offset); query_offset += stride;\n"
" FLOAT4 query_vec1=(x4+1 >= seq_len) ? (FLOAT4)0 : vload4(0,query+query_offset); query_offset += stride;\n"
" FLOAT4 query_vec2=(x4+2 >= seq_len) ? (FLOAT4)0 : vload4(0,query+query_offset); query_offset += stride;\n"
" FLOAT4 query_vec3=(x4+3 >= seq_len) ? (FLOAT4)0 : vload4(0,query+query_offset);\n"
" \n"
" const int queryout_offset=((b*head_num+z)*head_dim+y4)*seq_len4+x4;\n"
" vstore4((FLOAT4)(query_vec0.s0,query_vec1.s0,query_vec2.s0,query_vec3.s0),0,query_tmp+queryout_offset);\n"
" vstore4((FLOAT4)(query_vec0.s1,query_vec1.s1,query_vec2.s1,query_vec3.s1),0,query_tmp+queryout_offset+seq_len4);\n"
" vstore4((FLOAT4)(query_vec0.s2,query_vec1.s2,query_vec2.s2,query_vec3.s2),0,query_tmp+queryout_offset+seq_len4+seq_len4);\n"
" vstore4((FLOAT4)(query_vec0.s3,query_vec1.s3,query_vec2.s3,query_vec3.s3),0,query_tmp+queryout_offset+seq_len4+seq_len4+seq_len4);\n"
"}\n"
"__kernel void rearrange_k(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *key,// [batch key_seq_len kv_head_num head_dim]\n"
" __global FLOAT *past_key,// [batch kv_head_num head_dim max_length]\n"
" __private const int past_len,// prefill=0,decode=past_key len\n"
" __private const int max_len,\n"
" __private const int seq_len,\n"
" __private const int kv_head_num,\n"
" __private const int head_num,\n"
" __private const int head_dim) {\n"
" \n"
" const int x=get_global_id(0); // seq_len decode=1\n"
" const int y=get_global_id(1); // head_dim\n"
" int z=get_global_id(2); //\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" const int b=z/kv_head_num;\n"
" z=z % kv_head_num;\n"
" const int y4=y << 2;\n"
" \n"
"#ifdef OPENCL_PREFILL_ATTENTION\n"
" const int x4=x << 2;\n"
" const int stride=kv_head_num*head_dim;\n"
" int key_offset=((b*seq_len+x4)*kv_head_num+z)*head_dim+y4;\n"
" FLOAT4 key_vec0=vload4(0,key+key_offset); key_offset += stride;\n"
" FLOAT4 key_vec1=(x4+1 >= seq_len) ? (FLOAT4)0 : vload4(0,key+key_offset); key_offset += stride;\n"
" FLOAT4 key_vec2=(x4+2 >= seq_len) ? (FLOAT4)0 : vload4(0,key+key_offset); key_offset += stride;\n"
" FLOAT4 key_vec3=(x4+3 >= seq_len) ? (FLOAT4)0 : vload4(0,key+key_offset);\n"
" const int output_offset=((b*kv_head_num+z)*head_dim+y4)*max_len+past_len+x4;\n"
" vstore4((FLOAT4)(key_vec0.s0,key_vec1.s0,key_vec2.s0,key_vec3.s0),0,past_key+output_offset);\n"
" vstore4((FLOAT4)(key_vec0.s1,key_vec1.s1,key_vec2.s1,key_vec3.s1),0,past_key+output_offset+max_len);\n"
" vstore4((FLOAT4)(key_vec0.s2,key_vec1.s2,key_vec2.s2,key_vec3.s2),0,past_key+output_offset+max_len+max_len);\n"
" vstore4((FLOAT4)(key_vec0.s3,key_vec1.s3,key_vec2.s3,key_vec3.s3),0,past_key+output_offset+max_len+max_len+max_len);\n"
"#else\n"
" FLOAT4 key_vec=vload4(0,key+(b*kv_head_num+z)*head_dim+y4);\n"
" const int output_offset=((b*kv_head_num+z)*head_dim+y4)*max_len+past_len;\n"
" past_key[output_offset]=key_vec.s0;\n"
" past_key[output_offset+max_len]=key_vec.s1;\n"
" past_key[output_offset+max_len+max_len]=key_vec.s2;\n"
" past_key[output_offset+max_len+max_len+max_len]=key_vec.s3;\n"
"#endif\n"
"}\n"
"__kernel void rearrange_v(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *value,// [batch value_seq_len kv_head_num head_dim]\n"
" __global FLOAT *past_value,// [batch kv_head_num max_length head_dim]\n"
" __private const int past_len,\n"
" __private const int max_len,\n"
" __private const int seq_len,\n"
" __private const int kv_head_num,\n"
" __private const int head_dim) {\n"
" \n"
" const int x=get_global_id(0); // head_dim\n"
" const int y=get_global_id(1); // seq_len decode=1\n"
" int z=get_global_id(2); // kv_head_num\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" const int b=z/kv_head_num;\n"
" z=z % kv_head_num;\n"
" const int x4=x << 2;\n"
" \n"
"#ifdef OPENCL_PREFILL_ATTENTION\n"
" const int y4=y << 2;\n"
" const int stride=kv_head_num*head_dim;\n"
" int value_offset=((b*seq_len+y4)*kv_head_num+z)*head_dim+x4;\n"
" FLOAT4 value_vec0=vload4(0,value+value_offset); value_offset += stride;\n"
" FLOAT4 value_vec1=(y4+1 >= seq_len) ? (FLOAT4)0 : vload4(0,value+value_offset); value_offset += stride;\n"
" FLOAT4 value_vec2=(y4+2 >= seq_len) ? (FLOAT4)0 : vload4(0,value+value_offset); value_offset += stride;\n"
" FLOAT4 value_vec3=(y4+3 >= seq_len) ? (FLOAT4)0 : vload4(0,value+value_offset);\n"
" const int output_offset=((b*kv_head_num+z)*max_len+past_len+y4)*head_dim+x4;\n"
" vstore4(value_vec0,0,past_value+output_offset);\n"
" vstore4(value_vec1,0,past_value+output_offset+head_dim);\n"
" vstore4(value_vec2,0,past_value+output_offset+head_dim+head_dim);\n"
" vstore4(value_vec3,0,past_value+output_offset+head_dim+head_dim+head_dim);\n"
"#else\n"
" FLOAT4 value_vec=vload4(0,value+(b*kv_head_num+z)*head_dim+x4);\n"
" const int output_offset=((b*kv_head_num+z)*max_len+past_len)*head_dim+x4;\n"
" vstore4(value_vec,0,past_value+output_offset);\n"
"#endif\n"
"}\n"
"__kernel void rearrange_mask_shortprefill(GLOBAL_SIZE_3_DIMS\n"
" #ifdef ADD_MASK\n"
" __global const FLOAT* mask,\n"
" __global FLOAT* maskout,\n"
" #else\n"
" __global const int* mask,// [1 1 query_seq_len mask_key_seq_len4]\n"
" __global int* maskout,// [1 1 mask_key_seq_len4 query_seq_len4]\n"
" #endif\n"
" __private const int query_seq_len,\n"
" __private const int mask_key_seq_len){\n"
" const int x=get_global_id(0); // query_seq_len4\n"
" const int y=get_global_id(1); // mask_key_seq_len4\n"
" const int z=get_global_id(2); // batch\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" const int x4=x << 2;\n"
" const int y4=y << 2;\n"
" float4 mask_tmp0,mask_tmp1,mask_tmp2,mask_tmp3;\n"
" float4 mask0,mask1,mask2,mask3;\n"
" int mask_offset=x4*mask_key_seq_len+y4;\n"
" if(x4+3<query_seq_len && y4+3<mask_key_seq_len){\n"
" mask_tmp0=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp1=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp2=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp3=convert_float4(vload4(0,mask+mask_offset));\n"
" } else{\n"
" if(y4+3<mask_key_seq_len){\n"
" mask_tmp0=convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset)); mask_offset += mask_key_seq_len;\n"
" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : convert_float4(vload4(0,mask+mask_offset));\n"
" } else if(y4+1 == mask_key_seq_len){\n"
" mask_tmp0=(float4)(mask[mask_offset],0,0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],0,0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],0,0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],0,0,0);\n"
" }else if(y4+2 == mask_key_seq_len){\n"
" mask_tmp0=(float4)(mask[mask_offset],mask[mask_offset+1],0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],0,0); mask_offset += mask_key_seq_len;\n"
" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],0,0);\n"
" }else if(y4+3 == mask_key_seq_len){\n"
" mask_tmp0=(float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0); mask_offset += mask_key_seq_len;\n"
" mask_tmp1=(x4+1 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0); mask_offset += mask_key_seq_len;\n"
" mask_tmp2=(x4+2 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0); mask_offset += mask_key_seq_len;\n"
" mask_tmp3=(x4+3 >= query_seq_len) ? (float4)0 : (float4)(mask[mask_offset],mask[mask_offset+1],mask[mask_offset+2],0);\n"
" }\n"
" }\n"
" mask0=(float4)(mask_tmp0.s0,mask_tmp1.s0,mask_tmp2.s0,mask_tmp3.s0);\n"
" mask1=(float4)(mask_tmp0.s1,mask_tmp1.s1,mask_tmp2.s1,mask_tmp3.s1);\n"
" mask2=(float4)(mask_tmp0.s2,mask_tmp1.s2,mask_tmp2.s2,mask_tmp3.s2);\n"
" mask3=(float4)(mask_tmp0.s3,mask_tmp1.s3,mask_tmp2.s3,mask_tmp3.s3);\n"
" \n"
" int query_seq_len4=((query_seq_len+3)/4)*4;\n"
" int output_offset=y4*query_seq_len4+x4;\n"
" #ifdef ADD_MASK\n"
" vstore4(CONVERT_FLOAT4(mask0),0,maskout+output_offset);\n"
" vstore4(CONVERT_FLOAT4(mask1),0,maskout+output_offset+query_seq_len4);\n"
" vstore4(CONVERT_FLOAT4(mask2),0,maskout+output_offset+query_seq_len4+query_seq_len4);\n"
" vstore4(CONVERT_FLOAT4(mask3),0,maskout+output_offset+query_seq_len4+query_seq_len4+query_seq_len4);\n"
" #else\n"
" vstore4(convert_int4(mask0),0,maskout+output_offset);\n"
" vstore4(convert_int4(mask1),0,maskout+output_offset+query_seq_len4);\n"
" vstore4(convert_int4(mask2),0,maskout+output_offset+query_seq_len4+query_seq_len4);\n"
" vstore4(convert_int4(mask3),0,maskout+output_offset+query_seq_len4+query_seq_len4+query_seq_len4);\n"
" #endif\n"
"}\n"
"__kernel void matmul_qk_div_mask_prefill(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *query,// [batch head_num head_dim_4 query_seq_len_4]\n"
" __global const FLOAT *past_key,// [batch kv_head_num head_dim_4 kv_max_length]\n"
" #ifdef ADD_MASK\n"
" __global const FLOAT* mask,\n"
" #elif defined(SET_MASK)\n"
" __global const int* mask,// [1 1 query_seq_len mask_key_seq_len]\n"
" #endif\n"
" __global FLOAT *qk,// [batch head_num kv_seq_length query_seq_len_4]\n"
" __private const float scale,\n"
" __private const int query_seq_len,\n"
" __private const int mask_key_seq_len,\n"
" __private const int key_seq_len,\n"
" __private const int max_len,\n"
" __private const int head_num,\n"
" __private const int head_dim) {\n"
" \n"
" const int x=get_global_id(0); // query_seq_len\n"
" const int y=get_global_id(1); // kv_seq_length\n"
" const int z=get_global_id(2); // head_num*batch\n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" \n"
" const int x4=x << 2;\n"
" const int y4=y << 2;\n"
" \n"
" const int query_seq_len4=(query_seq_len+3)/4*4;;\n"
" const int query_offset=z*head_dim*query_seq_len4+x4;\n"
" const int past_offset=(z/NUMHEAD_GROUP_SIZE)*head_dim*max_len+y4;\n"
" float4 out0=0,out1=0,out2=0,out3=0;\n"
" \n"
" for(int i=0; i<head_dim/4; ++i){\n"
" int i4=i << 2;\n"
" float4 query_vec0=convert_float4(vload4(0,query+query_offset+i4*query_seq_len4));\n"
" float4 query_vec1=convert_float4(vload4(0,query+query_offset+(i4+1)*query_seq_len4));\n"
" float4 query_vec2=convert_float4(vload4(0,query+query_offset+(i4+2)*query_seq_len4));\n"
" float4 query_vec3=convert_float4(vload4(0,query+query_offset+(i4+3)*query_seq_len4));\n"
" \n"
" float4 past_vec0=convert_float4(vload4(0,past_key+past_offset+i4*max_len));\n"
" float4 past_vec1=convert_float4(vload4(0,past_key+past_offset+(i4+1)*max_len));\n"
" float4 past_vec2=convert_float4(vload4(0,past_key+past_offset+(i4+2)*max_len));\n"
" float4 past_vec3=convert_float4(vload4(0,past_key+past_offset+(i4+3)*max_len));\n"
" out0=mad((float4)past_vec0.s0,query_vec0,out0);\n"
" out0=mad((float4)past_vec1.s0,query_vec1,out0);\n"
" out0=mad((float4)past_vec2.s0,query_vec2,out0);\n"
" out0=mad((float4)past_vec3.s0,query_vec3,out0);\n"
" \n"
" out1=mad((float4)past_vec0.s1,query_vec0,out1);\n"
" out1=mad((float4)past_vec1.s1,query_vec1,out1);\n"
" out1=mad((float4)past_vec2.s1,query_vec2,out1);\n"
" out1=mad((float4)past_vec3.s1,query_vec3,out1);\n"
" \n"
" out2=mad((float4)past_vec0.s2,query_vec0,out2);\n"
" out2=mad((float4)past_vec1.s2,query_vec1,out2);\n"
" out2=mad((float4)past_vec2.s2,query_vec2,out2);\n"
" out2=mad((float4)past_vec3.s2,query_vec3,out2);\n"
" \n"
" out3=mad((float4)past_vec0.s3,query_vec0,out3);\n"
" out3=mad((float4)past_vec1.s3,query_vec1,out3);\n"
" out3=mad((float4)past_vec2.s3,query_vec2,out3);\n"
" out3=mad((float4)past_vec3.s3,query_vec3,out3);\n"
" }\n"
" out0 *= (float4)scale;\n"
" out1 *= (float4)scale;\n"
" out2 *= (float4)scale;\n"
" out3 *= (float4)scale;\n"
" {\n"
" #if defined(ADD_MASK) || defined(SET_MASK)\n"
" int query_seq_len4=((query_seq_len+3)/4)*4;\n"
" int mask_clp=y4+mask_key_seq_len-key_seq_len;\n"
" int mask_offset=mask_clp*query_seq_len4+x4;\n"
" float4 mask0=mask_clp >= 0 && mask_clp<mask_key_seq_len ? convert_float4(vload4(0,mask+mask_offset)) : 0; mask_offset += query_seq_len4;\n"
" float4 mask1=mask_clp+1 >= 0 && mask_clp+1<mask_key_seq_len? convert_float4(vload4(0,mask+mask_offset)) : 0; mask_offset += query_seq_len4;\n"
" float4 mask2=mask_clp+2 >= 0 && mask_clp+2<mask_key_seq_len? convert_float4(vload4(0,mask+mask_offset)) : 0; mask_offset += query_seq_len4;\n"
" float4 mask3=mask_clp+3 >= 0 && mask_clp+3<mask_key_seq_len? convert_float4(vload4(0,mask+mask_offset)) : 0;\n"
" #endif\n"
" \n"
" #ifdef ADD_MASK\n"
" out0 += mask0;\n"
" out1 += mask1;\n"
" out2 += mask2;\n"
" out3 += mask3;\n"
" #elif defined(SET_MASK)\n"
" out0=(mask0 == (float4)0) ? (float4)(-FLT_MAX) : out0;\n"
" out1=(mask1 == (float4)0) ? (float4)(-FLT_MAX) : out1;\n"
" out2=(mask2 == (float4)0) ? (float4)(-FLT_MAX) : out2;\n"
" out3=(mask3 == (float4)0) ? (float4)(-FLT_MAX) : out3;\n"
" #endif\n"
" }\n"
" \n"
" const int qk_offset=(z*key_seq_len+y4)*query_seq_len4+x4;\n"
" vstore4(CONVERT_FLOAT4(out0),0,qk+qk_offset);\n"
" if(y4+1 >= key_seq_len) return;\n"
" vstore4(CONVERT_FLOAT4(out1),0,qk+qk_offset+query_seq_len4);\n"
" if(y4+2 >= key_seq_len) return;\n"
" vstore4(CONVERT_FLOAT4(out2),0,qk+qk_offset+query_seq_len4+query_seq_len4);\n"
" if(y4+3 >= key_seq_len) return;\n"
" vstore4(CONVERT_FLOAT4(out3),0,qk+qk_offset+query_seq_len4+query_seq_len4+query_seq_len4);\n"
"}\n"
"__kernel void matmul_qk_decode(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *query,// key [1 head_num head_dim]\n"
" __global const FLOAT *past_key,// [1 head_num head_dim max_length]\n"
" __global FLOAT *qk,// [1 head_num key_seq_len 1]\n"
" __private const float scale,\n"
" __private const int seq_len,\n"
" __private const int max_len,\n"
" __private const int head_num,\n"
" __private const int head_dim) {\n"
" \n"
" const int x=get_global_id(0); // key_seq_len\n"
" const int y=get_global_id(1); // head_num\n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" const int x4=x << 2;\n"
" \n"
" const int query_offset=y*head_dim;\n"
" const int past_offset=(y/NUMHEAD_GROUP_SIZE)*head_dim*max_len+x4;\n"
" float4 out0=0;\n"
" \n"
" for(int i=0; i<head_dim/4; ++i){\n"
" int i4=i << 2;\n"
" float4 query_vec=convert_float4(vload4(0,query+query_offset+i4));\n"
" \n"
" float4 past_vec0=convert_float4(vload4(0,past_key+past_offset+i4*max_len));\n"
" float4 past_vec1=convert_float4(vload4(0,past_key+past_offset+(i4+1)*max_len));\n"
" float4 past_vec2=convert_float4(vload4(0,past_key+past_offset+(i4+2)*max_len));\n"
" float4 past_vec3=convert_float4(vload4(0,past_key+past_offset+(i4+3)*max_len));\n"
" \n"
" out0=mad((float4)query_vec.s0,past_vec0,out0);\n"
" out0=mad((float4)query_vec.s1,past_vec1,out0);\n"
" out0=mad((float4)query_vec.s2,past_vec2,out0);\n"
" out0=mad((float4)query_vec.s3,past_vec3,out0);\n"
" }\n"
" out0 *= (float4)scale;\n"
" const int qk_offset=y*seq_len+x4;\n"
" if(x4+3<seq_len){\n"
" vstore4(CONVERT_FLOAT4(out0),0,qk+qk_offset);\n"
" }else {\n"
" int remain=seq_len-x4;\n"
" if(remain == 3){\n"
" vstore3(CONVERT_FLOAT3((float3)(out0.s012)),0,qk+qk_offset);\n"
" } else if(remain == 2){\n"
" vstore2(CONVERT_FLOAT2((float2)(out0.s01)),0,qk+qk_offset);\n"
" }else if(remain == 1){\n"
" qk[qk_offset]=out0.s0;\n"
" }\n"
" }\n"
"}\n"
"// Prefill-phase attention output: output[b,y,z,:] = sum_i qk[b,z,i,y] * past_value[b,z/g,i,:].\n"
"// Each work-item produces an 8-wide slice of head_dim (x8) for four\n"
"// consecutive query positions (y4..y4+3) of one (batch, head) pair.\n"
"__kernel void matmul_qkv_prefill(GLOBAL_SIZE_3_DIMS\n"
" __global const FLOAT *qk,// qk prefill [batch head_num kv_seq_length query_seq_len]\n"
" __global const FLOAT *past_value,// [batch kv_head_num max_len head_dim]\n"
" __global FLOAT *output,// [batch query_seq_len head_num head_dim]\n"
" __private const int query_seq_len,\n"
" __private const int kv_seq_len,\n"
" __private const int max_len,\n"
" __private const int head_num,\n"
" __private const int kv_head_num,\n"
" __private const int head_dim) {\n"
" \n"
" const int x=get_global_id(0); // head_dim\n"
" const int y=get_global_id(1); // query_seq_len\n"
" int z=get_global_id(2); // head_num*batch\n"
" \n"
" DEAL_NON_UNIFORM_DIM3(x,y,z);\n"
" // Split the flattened id 2 into batch index b and head index z.\n"
" const int b=z/head_num;\n"
" z=z % head_num;\n"
" // Each work-item covers 8 head_dim channels and 4 query rows.\n"
" const int x8=x << 3;\n"
" const int y4=y << 2;\n"
" \n"
" // qk rows are padded to a multiple of 4 along query_seq_len.\n"
" const int query_seq_len4=(query_seq_len+3)/4*4;\n"
" const int qk_offset=(b*head_num+z)*kv_seq_len*query_seq_len4+y4;\n"
" // KV heads are shared across NUMHEAD_GROUP_SIZE query heads (grouped-query attention).\n"
" const int past_offset=((b*kv_head_num+z/NUMHEAD_GROUP_SIZE)*max_len)*head_dim+x8;\n"
" // Main loop processes kv positions 4 at a time; the last full group and\n"
" // any remainder are left to the scalar tail loop below (note the -1).\n"
" const int loop_end=max(kv_seq_len/4-1,0);\n"
" COMPUTE_FLOAT8 out0=0,out1=0,out2=0,out3=0;\n"
" \n"
" for(int i=0; i<loop_end; ++i){\n"
" int i4=i << 2;\n"
" // 4 query columns for kv rows i4..i4+3.\n"
" COMPUTE_FLOAT4 qk_vec0=CONVERT_COMPUTE_FLOAT4(vload4(0,qk+qk_offset+i4*query_seq_len4));\n"
" COMPUTE_FLOAT4 qk_vec1=CONVERT_COMPUTE_FLOAT4(vload4(0,qk+qk_offset+(i4+1)*query_seq_len4));\n"
" COMPUTE_FLOAT4 qk_vec2=CONVERT_COMPUTE_FLOAT4(vload4(0,qk+qk_offset+(i4+2)*query_seq_len4));\n"
" COMPUTE_FLOAT4 qk_vec3=CONVERT_COMPUTE_FLOAT4(vload4(0,qk+qk_offset+(i4+3)*query_seq_len4));\n"
" \n"
" // 8 head_dim channels of the same 4 kv rows.\n"
" COMPUTE_FLOAT8 past_vec0=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+i4*head_dim));\n"
" COMPUTE_FLOAT8 past_vec1=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i4+1)*head_dim));\n"
" COMPUTE_FLOAT8 past_vec2=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i4+2)*head_dim));\n"
" COMPUTE_FLOAT8 past_vec3=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i4+3)*head_dim));\n"
" \n"
" // Accumulate one output row per qk lane (out0..out3 = query rows y4..y4+3).\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec0.s0,past_vec0,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec1.s0,past_vec1,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec2.s0,past_vec2,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec3.s0,past_vec3,out0);\n"
" \n"
" out1=mad((COMPUTE_FLOAT8)qk_vec0.s1,past_vec0,out1);\n"
" out1=mad((COMPUTE_FLOAT8)qk_vec1.s1,past_vec1,out1);\n"
" out1=mad((COMPUTE_FLOAT8)qk_vec2.s1,past_vec2,out1);\n"
" out1=mad((COMPUTE_FLOAT8)qk_vec3.s1,past_vec3,out1);\n"
" \n"
" out2=mad((COMPUTE_FLOAT8)qk_vec0.s2,past_vec0,out2);\n"
" out2=mad((COMPUTE_FLOAT8)qk_vec1.s2,past_vec1,out2);\n"
" out2=mad((COMPUTE_FLOAT8)qk_vec2.s2,past_vec2,out2);\n"
" out2=mad((COMPUTE_FLOAT8)qk_vec3.s2,past_vec3,out2);\n"
" \n"
" out3=mad((COMPUTE_FLOAT8)qk_vec0.s3,past_vec0,out3);\n"
" out3=mad((COMPUTE_FLOAT8)qk_vec1.s3,past_vec1,out3);\n"
" out3=mad((COMPUTE_FLOAT8)qk_vec2.s3,past_vec2,out3);\n"
" out3=mad((COMPUTE_FLOAT8)qk_vec3.s3,past_vec3,out3);\n"
" }\n"
" // Scalar tail: kv positions not covered by the unrolled loop above.\n"
" for(int i=(loop_end << 2); i<kv_seq_len; ++i){\n"
" COMPUTE_FLOAT4 qk_vec=CONVERT_COMPUTE_FLOAT4(vload4(0,qk+qk_offset+i*query_seq_len4));\n"
" COMPUTE_FLOAT8 past_vec=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+i*head_dim));\n"
" \n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s0,past_vec,out0);\n"
" out1=mad((COMPUTE_FLOAT8)qk_vec.s1,past_vec,out1);\n"
" out2=mad((COMPUTE_FLOAT8)qk_vec.s2,past_vec,out2);\n"
" out3=mad((COMPUTE_FLOAT8)qk_vec.s3,past_vec,out3);\n"
" }\n"
" \n"
" // Store up to 4 query rows; the early returns guard the case where\n"
" // query_seq_len is not a multiple of 4 (output has no row padding).\n"
" const int output_offset=((b*query_seq_len+y4)*head_num+z)*head_dim+x8;\n"
" const int stride=head_num*head_dim;\n"
" vstore8(CONVERT_FLOAT8(out0),0,output+output_offset);\n"
" if(y4+1 >= query_seq_len) return;\n"
" vstore8(CONVERT_FLOAT8(out1),0,output+output_offset+stride);\n"
" if(y4+2 >= query_seq_len) return;\n"
" vstore8(CONVERT_FLOAT8(out2),0,output+output_offset+stride+stride);\n"
" if(y4+3 >= query_seq_len) return;\n"
" vstore8(CONVERT_FLOAT8(out3),0,output+output_offset+stride+stride+stride);\n"
"}\n"
"// Decode-phase attention output (single query token):\n"
"// output[y,x8..x8+7] = sum_i qk[y,i] * past_value[y/g,i,x8..x8+7].\n"
"// Each work-item produces 8 consecutive head_dim channels for one head.\n"
"// NOTE(review): the host must define exactly one of LOOP_UNROLL_4 /\n"
"// LOOP_UNROLL_8; with neither defined, out0 stays zero and zeros are stored.\n"
"__kernel void matmul_qkv_decode_b8(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *qk,// qk [1 head_num qk_seq_len 1]\n"
" __global const FLOAT *past_value,// [1 head_num max_len head_dim]\n"
" __global FLOAT *output,// [1 1 head_num head_dim]\n"
" __private const int qk_seq_len,\n"
" __private const int max_len,\n"
" __private const int head_num,\n"
" __private const int kv_head_num,\n"
" __private const int head_dim) {\n"
" \n"
" const int x=get_global_id(0); // head_dim\n"
" const int y=get_global_id(1); // head_num\n"
" \n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" const int x8=x << 3;\n"
" \n"
" const int qk_offset=y*qk_seq_len;\n"
" // KV heads are shared across NUMHEAD_GROUP_SIZE query heads.\n"
" const int past_offset=((y/NUMHEAD_GROUP_SIZE)*max_len)*head_dim+x8;\n"
" COMPUTE_FLOAT8 out0=0;\n"
" #ifdef LOOP_UNROLL_4\n"
" // 4-way unrolled accumulation; the -1 keeps the vload4 on qk in bounds,\n"
" // the scalar tail loop finishes the remaining positions.\n"
" const int loop_end=max((qk_seq_len+3)/4-1,0);\n"
" for(int i=0; i<loop_end; ++i){\n"
" int i4=i << 2;\n"
" COMPUTE_FLOAT4 qk_vec=CONVERT_COMPUTE_FLOAT4(vload4(0,qk+qk_offset+i4));\n"
" \n"
" COMPUTE_FLOAT8 past_vec0=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+i4*head_dim));\n"
" COMPUTE_FLOAT8 past_vec1=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i4+1)*head_dim));\n"
" COMPUTE_FLOAT8 past_vec2=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i4+2)*head_dim));\n"
" COMPUTE_FLOAT8 past_vec3=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i4+3)*head_dim));\n"
" \n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s0,past_vec0,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s1,past_vec1,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s2,past_vec2,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s3,past_vec3,out0);\n"
" }\n"
" // Scalar tail for positions beyond the unrolled loop.\n"
" for(int i=(loop_end << 2); i<qk_seq_len; ++i){\n"
" COMPUTE_FLOAT qk_vec=qk[qk_offset+i];\n"
" COMPUTE_FLOAT8 past_vec=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+i*head_dim));\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec,past_vec,out0);\n"
" }\n"
" #elif (defined LOOP_UNROLL_8)\n"
" // 8-way unrolled accumulation; same tail-loop scheme as above.\n"
" const int loop_end=max((qk_seq_len+7)/8-1,0);\n"
" for(int i=0; i<loop_end; ++i){\n"
" int i8=i << 3;\n"
" COMPUTE_FLOAT8 qk_vec=CONVERT_COMPUTE_FLOAT8(vload8(0,qk+qk_offset+i8));\n"
" \n"
" COMPUTE_FLOAT8 past_vec0=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+i8*head_dim));\n"
" COMPUTE_FLOAT8 past_vec1=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i8+1)*head_dim));\n"
" COMPUTE_FLOAT8 past_vec2=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i8+2)*head_dim));\n"
" COMPUTE_FLOAT8 past_vec3=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i8+3)*head_dim));\n"
" COMPUTE_FLOAT8 past_vec4=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i8+4)*head_dim));\n"
" COMPUTE_FLOAT8 past_vec5=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i8+5)*head_dim));\n"
" COMPUTE_FLOAT8 past_vec6=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i8+6)*head_dim));\n"
" COMPUTE_FLOAT8 past_vec7=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+(i8+7)*head_dim));\n"
" \n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s0,past_vec0,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s1,past_vec1,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s2,past_vec2,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s3,past_vec3,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s4,past_vec4,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s5,past_vec5,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s6,past_vec6,out0);\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec.s7,past_vec7,out0);\n"
" }\n"
" for(int i=(loop_end << 3); i<qk_seq_len; ++i){\n"
" COMPUTE_FLOAT qk_vec=qk[qk_offset+i];\n"
" COMPUTE_FLOAT8 past_vec=CONVERT_COMPUTE_FLOAT8(vload8(0,past_value+past_offset+i*head_dim));\n"
" out0=mad((COMPUTE_FLOAT8)qk_vec,past_vec,out0);\n"
" }\n"
" #endif\n"
" \n"
" const int output_offset=y*head_dim+x8;\n"
" vstore8(CONVERT_FLOAT8(out0),0,output+output_offset);\n"
"}\n"
"// 4-wide variant of matmul_qkv_decode_b8: each work-item produces 4\n"
"// consecutive head_dim channels (x4) of one head for the single decode token.\n"
"// output[y,x4..x4+3] = sum_i qk[y,i] * past_value[y/g,i,x4..x4+3].\n"
"// NOTE(review): the host must define exactly one of LOOP_UNROLL_4 /\n"
"// LOOP_UNROLL_8; with neither defined, out0 stays zero and zeros are stored.\n"
"__kernel void matmul_qkv_decode_b4(GLOBAL_SIZE_2_DIMS\n"
" __global const FLOAT *qk,// qk [1 head_num qk_seq_len 1]\n"
" __global const FLOAT *past_value,// [1 head_num max_len head_dim]\n"
" __global FLOAT *output,// [1 1 head_num head_dim]\n"
" __private const int qk_seq_len,\n"
" __private const int max_len,\n"
" __private const int head_num,\n"
" __private const int kv_head_num,\n"
" __private const int head_dim) {\n"
" \n"
" const int x=get_global_id(0); // head_dim\n"
" const int y=get_global_id(1); // head_num\n"
" \n"
" DEAL_NON_UNIFORM_DIM2(x,y);\n"
" const int x4=x << 2;\n"
" \n"
" const int qk_offset=y*qk_seq_len;\n"
" // KV heads are shared across NUMHEAD_GROUP_SIZE query heads.\n"
" const int past_offset=((y/NUMHEAD_GROUP_SIZE)*max_len)*head_dim+x4;\n"
" COMPUTE_FLOAT4 out0=0;\n"
" #ifdef LOOP_UNROLL_4\n"
" // 4-way unrolled accumulation; the -1 keeps the vload4 on qk in bounds,\n"
" // the scalar tail loop finishes the remaining positions.\n"
" const int loop_end=max((qk_seq_len+3)/4-1,0);\n"
" for(int i=0; i<loop_end; ++i){\n"
" int i4=i << 2;\n"
" COMPUTE_FLOAT4 qk_vec=CONVERT_COMPUTE_FLOAT4(vload4(0,qk+qk_offset+i4));\n"
" \n"
" COMPUTE_FLOAT4 past_vec0=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+i4*head_dim));\n"
" COMPUTE_FLOAT4 past_vec1=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+(i4+1)*head_dim));\n"
" COMPUTE_FLOAT4 past_vec2=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+(i4+2)*head_dim));\n"
" COMPUTE_FLOAT4 past_vec3=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+(i4+3)*head_dim));\n"
" \n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s0,past_vec0,out0);\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s1,past_vec1,out0);\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s2,past_vec2,out0);\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s3,past_vec3,out0);\n"
" }\n"
" // Scalar tail for positions beyond the unrolled loop.\n"
" for(int i=(loop_end << 2); i<qk_seq_len; ++i){\n"
" COMPUTE_FLOAT qk_vec=qk[qk_offset+i];\n"
" COMPUTE_FLOAT4 past_vec=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+i*head_dim));\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec,past_vec,out0);\n"
" }\n"
" #elif (defined LOOP_UNROLL_8)\n"
" // 8-way unrolled accumulation; same tail-loop scheme as above.\n"
" const int loop_end=max((qk_seq_len+7)/8-1,0);\n"
" for(int i=0; i<loop_end; ++i){\n"
" int i8=i << 3;\n"
" COMPUTE_FLOAT8 qk_vec=CONVERT_COMPUTE_FLOAT8(vload8(0,qk+qk_offset+i8));\n"
" \n"
" COMPUTE_FLOAT4 past_vec0=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+i8*head_dim));\n"
" COMPUTE_FLOAT4 past_vec1=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+(i8+1)*head_dim));\n"
" COMPUTE_FLOAT4 past_vec2=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+(i8+2)*head_dim));\n"
" COMPUTE_FLOAT4 past_vec3=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+(i8+3)*head_dim));\n"
" COMPUTE_FLOAT4 past_vec4=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+(i8+4)*head_dim));\n"
" COMPUTE_FLOAT4 past_vec5=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+(i8+5)*head_dim));\n"
" COMPUTE_FLOAT4 past_vec6=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+(i8+6)*head_dim));\n"
" COMPUTE_FLOAT4 past_vec7=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+(i8+7)*head_dim));\n"
" \n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s0,past_vec0,out0);\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s1,past_vec1,out0);\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s2,past_vec2,out0);\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s3,past_vec3,out0);\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s4,past_vec4,out0);\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s5,past_vec5,out0);\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s6,past_vec6,out0);\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec.s7,past_vec7,out0);\n"
" }\n"
" for(int i=(loop_end << 3); i<qk_seq_len; ++i){\n"
" COMPUTE_FLOAT qk_vec=qk[qk_offset+i];\n"
" COMPUTE_FLOAT4 past_vec=CONVERT_COMPUTE_FLOAT4(vload4(0,past_value+past_offset+i*head_dim));\n"
" out0=mad((COMPUTE_FLOAT4)qk_vec,past_vec,out0);\n"
" }\n"
" #endif\n"
" \n"
" const int output_offset=y*head_dim+x4;\n"
" vstore4(CONVERT_FLOAT4(out0),0,output+output_offset);\n"
"}\n"
;
#endif
}
